import numpy as np
import os
import csv
import json
from tokenizer import tokenize
import pickle as pkl
from pymongo import MongoClient
# Connection to the MongoDB server on port 27017; every collection read by
# this script lives in the 'tiktok' database.
client = MongoClient(port=27017)
db=client['tiktok']
# Hashtag names. NOTE(review): not referenced by the visible code, which
# iterates the keys of `ids` instead -- confirm whether it is still needed.
collist=['wip', 'haventseen', 'ootd', 'bekind', 'personalfinance', 'cozyathome', 'RoomTour', 'theatrekids', 'ImAGhost', 'holidaytiktok', 'halloweenlook', 'happyhalloween', 'welldone', 'motivationmonday', 'thinkingabout', 'nonuancenovember', 'ourtype', 'fanedit', 'needtoknow', 'cleantok', 'graphicdesign', 'readysetshop', 'holidaysourway', 'onlinedating', 'myhobby', 'tiktokfood', 'whereilive', 'myrecommendation', 'worldseries', 'animation', 'cocinando', 'easydiy', 'diceroll', 'rnbvibes', 'festivefashion', 'holidaydecor', 'nbadraft', 'halloweenishere', 'christmas2020', 'howbizarre', 'sfxmakeup', 'givingthanks', 'holidayvibes', 'gamenight', 'happyholidays', 'carsoftiktok', 'fallguysmoments', 'interiordesign', 'homecooked', 'veteransday', 'youwantmore', 'coldweather', 'wildanimals', 'mycostume', 'meleaving', 'mypfp', 'catchphrases', 'watchmegrow', 'holidaycrafts', 'growupwithme', 'clingypet', 'happyhanukkah', 'lunarnewyear', 'tabletop', 'comfortfood', 'selfimprovement', '2021affirmations', 'perfectmatch', 'givingszn', 'holidaycountdown', 'bakingszn', 'holidaymusic', 'familyimpression', 'inkdrawing', 'WeekendVibes', 'recordsday', 'productivity', 'smallbusiness', 'falldiy', 'whenwewereyounger', 'yellow', 'ComingOfAge', 'artmas', 'gaminglife', 'gamingsetup', 'hellowinter', 'planttiktok', 'housetour', 'neonshadow', 'homeoffice', 'raisedby', 'makeitvogue', 'foodtiktok', 'valentinesday', 'yougotthis', 'stemlife']
# Maps the first TSV column to the values found from the third column onward.
# The export loop further down uses each key as a collection name and each
# value as a document `_id`.
ids={}
# Embedding tables (indexed by the integers stored in `word_index`) and the
# token -> row-index mapping.
embedding_matrix = np.load('E:\\data_pi\\embedding_matrix.npy')
embedding_matrix_norm = np.load('E:\\data_pi\\embedding_matrix_norm.npy')
# NOTE: unpickling executes arbitrary code; only load this file from a
# trusted location. The `with` block makes sure the handle is closed.
with open('E:\\data_pi\\word_index.pkl', 'rb') as fpkl:
    word_index = pkl.load(fpkl)
with open('D:\\Work\\kusuri\\newlyaddedids.tsv','r',newline='\n',encoding='utf-8')as fin:
    reader = csv.reader(fin, delimiter='\t')
    for line in reader:
        # csv.reader yields [] for a blank line; skip it instead of failing
        # on line[0].
        if not line:
            continue
        # The second column (line[1]) is ignored; rows sharing the same first
        # column are merged into a single list.
        ids.setdefault(line[0], []).extend(line[2:])

# Number of per-frame embeddings kept per video, and number of leading tokens
# considered per text; shorter sequences are zero-padded to these lengths.
img_sequence_len = 13
sequence_len = 20

# Keys of obj['video_feature']['editing'] that must be present for a document
# to be exported.
_REQUIRED_EDIT_KEYS = ('var_sb', 'avg_sticker_length', 'avg_scences', 'var_yamnet')
# Keys of obj['video_feature']['editing'] written, in this order, into each
# row of the saved edit-feature array.
# NOTE(review): only the four keys above are checked beforehand; a document
# lacking one of the others raises KeyError -- confirm they always co-occur.
_EDIT_FEATURE_KEYS = ('video_len', 'var_sb', 'var_sb_c', 'var_vgg', 'var_vgg_c',
                      'var_yamnet', 'sticker_num', 'avg_sticker_length',
                      'avg_scences')
# Prefix shared by every output file.
_OUT_PREFIX = 'E:\\newlyadded\\'


def _encode_text(raw_text, max_len):
    """Convert raw text into fixed-length index and embedding sequences.

    The text is tokenized and only its first ``max_len`` tokens are
    considered. Tokens missing from ``word_index`` are dropped without being
    replaced by later tokens. The results are right-padded to ``max_len``
    with index 0 and all-zero vectors.

    Args:
        raw_text: Text passed to ``tokenize``.
        max_len: Length of each returned sequence.

    Returns:
        A tuple ``(indices, embeds, embeds_norm)`` of three parallel lists of
        length ``max_len``: ``word_index`` values, the matching rows of
        ``embedding_matrix`` and the matching rows of
        ``embedding_matrix_norm``.
    """
    indices = [word_index[token]
               for token in tokenize(raw_text)[:max_len]
               if token in word_index]
    embeds = [embedding_matrix[idx] for idx in indices]
    embeds_norm = [embedding_matrix_norm[idx] for idx in indices]

    missing = max_len - len(indices)
    indices.extend([0] * missing)
    embeds.extend(np.zeros(embedding_matrix.shape[1]) for _ in range(missing))
    embeds_norm.extend(
        np.zeros(embedding_matrix_norm.shape[1]) for _ in range(missing))
    return indices, embeds, embeds_norm


# For every hashtag: read the listed documents from its collection, build the
# per-video feature rows and save one .npy file per feature type.
for ht, cids in ids.items():
    # --- Pass 1: pull the raw features of every usable document. ---
    idlist = []          # document ids, in export order
    img = {}             # _id -> element-wise max over frames of values [4096:]
    img_embed = {}       # _id -> list of per-frame vectors (values [:4096])
    yamnet = {}          # _id -> obj['video_feature']['audio']['yamnet']
    texts = {}           # _id -> obj['text_feature']['text']
    texts_sticker = {}   # _id -> sticker texts joined with trailing spaces
    edits = {}           # _id -> values of _EDIT_FEATURE_KEYS
    for cid in cids:
        for obj in db[ht].find({'_id': cid}):
            video = obj['video_feature']
            # Skip documents with no frames, no text, no yamnet output or an
            # incomplete set of editing statistics.
            if not (len(video['img_embed']) > 0
                    and len(obj['text_feature']['text']) > 0
                    and len(video['audio']['yamnet']) > 0
                    and all(key in video['editing']
                            for key in _REQUIRED_EDIT_KEYS)):
                continue

            vid = obj['_id']
            idlist.append(vid)

            # Each frame vector is split at position 4096: the head is kept
            # per frame, the tail is reduced to its element-wise maximum
            # across all frames (seeded with a copy of frame '0').
            frames = video['img_embed']
            frame_heads = []
            maximg = frames['0'][4096:]
            for frame in frames.values():
                frame_heads.append(frame[:4096])
                for pos, value in enumerate(frame[4096:]):
                    if maximg[pos] < value:
                        maximg[pos] = value
            img_embed[vid] = frame_heads
            img[vid] = maximg
            yamnet[vid] = video['audio']['yamnet']

            texts[vid] = obj['text_feature']['text']
            texts_sticker[vid] = ''.join(
                item + ' ' for item in obj['text_feature']['stickerText'])
            # Label extraction is currently disabled:
            # labels[obj['_id']]=[int(obj['video_feature']['label']['labelA']),int(obj['video_feature']['label']['labelB']),float(obj['video_feature']['residual']['residualA']),float(obj['video_feature']['residual']['residualB'])]

            editing = video['editing']
            edits[vid] = [editing[key] for key in _EDIT_FEATURE_KEYS]

    # --- Pass 2: assemble fixed-length rows, one per exported document. ---
    train_ids = []
    train_img_emb = []
    train_img_emb_new = []
    train_yamnet = []
    train_edits = []
    train_text = []
    train_text_embed = []
    train_text_embed_norm = []
    train_text_sticker = []
    train_text_embed_sticker = []
    train_text_embed_norm_sticker = []

    for item in idlist:
        train_ids.append(item)
        train_img_emb.append(img[item])
        train_yamnet.append(yamnet[item])

        text, text_embed, text_embed_norm = _encode_text(
            texts[item], sequence_len)
        train_text.append(text)
        train_text_embed.append(text_embed)
        train_text_embed_norm.append(text_embed_norm)

        text_sticker, text_embed_sticker, text_embed_norm_sticker = \
            _encode_text(texts_sticker[item], sequence_len)
        train_text_sticker.append(text_sticker)
        train_text_embed_sticker.append(text_embed_sticker)
        train_text_embed_norm_sticker.append(text_embed_norm_sticker)

        # Keep at most img_sequence_len frames and zero-pad shorter videos.
        im = list(img_embed[item][:img_sequence_len])
        im.extend(np.zeros(4096) for _ in range(img_sequence_len - len(im)))
        train_img_emb_new.append(im)

        train_edits.append(edits[item])

    # Numeric feature blocks are stored as float32. The yamnet conversion is
    # bound to `train_yamnet` so that the converted array is the one saved.
    train_img_emb = np.array(train_img_emb).astype(np.float32)
    train_yamnet = np.array(train_yamnet).astype(np.float32)
    train_img_emb_new = np.array(train_img_emb_new).astype(np.float32)
    train_edits = np.array(train_edits).astype(np.float32)

    np.save(_OUT_PREFIX + 'image_embed_prob_newlyadded_' + ht, train_img_emb)
    np.save(_OUT_PREFIX + 'yamnet_embed_newlyadded_' + ht, train_yamnet)
    np.save(_OUT_PREFIX + 'text_newlyadded_' + ht, train_text)
    np.save(_OUT_PREFIX + 'edit_embed_newlyadded_' + ht, train_edits)
    np.save(_OUT_PREFIX + 'image_embed_newlyadded_' + ht, train_img_emb_new)
    np.save(_OUT_PREFIX + 'ids_newlyadded_' + ht, train_ids)

    np.save(_OUT_PREFIX + 'text_embed_newlyadded_' + ht, train_text_embed)
    np.save(_OUT_PREFIX + 'text_embed_norm_newlyadded_' + ht,
            train_text_embed_norm)

    np.save(_OUT_PREFIX + 'text_sticker_newlyadded_' + ht, train_text_sticker)
    np.save(_OUT_PREFIX + 'text_sticker_embed_newlyadded_' + ht,
            train_text_embed_sticker)
    np.save(_OUT_PREFIX + 'text_sticker_embed_norm_newlyadded_' + ht,
            train_text_embed_norm_sticker)